{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup Amazon EventBridge To Trigger a Pipeline Execution with S3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Amazon EventBridge is a serverless event bus that makes it easy to connect applications together using data from your own applications, integrated Software-as-a-Service (SaaS) applications, and AWS services.\n",
    "\n",
    "You can choose an event source (i.e. Amazon S3) and select a target from a number of AWS services including AWS Step Functions, AWS Lambda, Amazon SNS, and Amazon Kinesis Data Firehose. Amazon EventBridge will automatically deliver the events in near real-time."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/automated_pipeline.png\" width=\"90%\" align=\"left\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sagemaker\n",
    "import logging\n",
    "import boto3\n",
    "import sagemaker\n",
    "import pandas as pd\n",
    "import json\n",
    "from botocore.exceptions import ClientError\n",
    "\n",
    "sess = sagemaker.Session()\n",
    "bucket = sess.default_bucket()\n",
    "role = sagemaker.get_execution_role()\n",
    "region = boto3.Session().region_name\n",
    "\n",
    "sm = boto3.Session().client(service_name=\"sagemaker\", region_name=region)\n",
    "account_id = boto3.client(\"sts\").get_caller_identity().get(\"Account\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Get the StepFunctions ARN and Name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r stepfunction_arn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    stepfunction_arn\n",
    "    print(\"[OK]\")\n",
    "except NameError:\n",
    "    print(\"+++++++++++++++++++++++++++++++\")\n",
    "    print(\"[ERROR] Please run the notebooks in this section before you continue.\")\n",
    "    print(\"+++++++++++++++++++++++++++++++\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(stepfunction_arn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r stepfunction_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    stepfunction_name\n",
    "    print(\"[OK]\")\n",
    "except NameError:\n",
    "    print(\"+++++++++++++++++++++++++++++++\")\n",
    "    print(\"[ERROR] Please run the notebooks in this section before you continue.\")\n",
    "    print(\"+++++++++++++++++++++++++++++++\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(stepfunction_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Steps\n",
    "1. Create S3 Buckets\n",
    "2. Enable CloudTrail Logging\n",
    "3. Get StepFunctions Pipeline\n",
    "4. Create EventBridge Rule\n",
    "5. Test Trigger"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create S3 Data Upload Bucket (watched) & S3 Bucket for CloudTrail Logs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "watched_bucket = \"dsoaws-test-upload-{}\".format(account_id)\n",
    "print(watched_bucket)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws s3 mb s3://$watched_bucket"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws s3 ls $watched_bucket"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cloudtrail_bucket = \"cloudtrail-dsoaws-{}\".format(account_id)\n",
    "print(cloudtrail_bucket)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws s3 mb s3://$cloudtrail_bucket"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws s3 ls $cloudtrail_bucket"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Attach an S3 Policy to the Cloud Trail ^^ Logging Bucket ^^ Above"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "policy = {\n",
    "    \"Version\": \"2012-10-17\",\n",
    "    \"Statement\": [\n",
    "        {\n",
    "            \"Sid\": \"AWSCloudTrailAclCheck20150319\",\n",
    "            \"Effect\": \"Allow\",\n",
    "            \"Principal\": {\"Service\": \"cloudtrail.amazonaws.com\"},\n",
    "            \"Action\": \"s3:GetBucketAcl\",\n",
    "            \"Resource\": \"arn:aws:s3:::{}\".format(cloudtrail_bucket),\n",
    "        },\n",
    "        {\n",
    "            \"Sid\": \"AWSCloudTrailWrite20150319\",\n",
    "            \"Effect\": \"Allow\",\n",
    "            \"Principal\": {\"Service\": \"cloudtrail.amazonaws.com\"},\n",
    "            \"Action\": \"s3:PutObject\",\n",
    "            \"Resource\": \"arn:aws:s3:::{}/AWSLogs/{}/*\".format(cloudtrail_bucket, account_id),\n",
    "            \"Condition\": {\"StringEquals\": {\"s3:x-amz-acl\": \"bucket-owner-full-control\"}},\n",
    "        },\n",
    "        {\n",
    "            \"Sid\": \"AWSCloudTrailHTTPSOnly20180329\",\n",
    "            \"Effect\": \"Deny\",\n",
    "            \"Principal\": {\"Service\": \"cloudtrail.amazonaws.com\"},\n",
    "            \"Action\": \"s3:*\",\n",
    "            \"Resource\": [\n",
    "                \"arn:aws:s3:::{}/AWSLogs/{}/*\".format(cloudtrail_bucket, account_id),\n",
    "                \"arn:aws:s3:::{}\".format(cloudtrail_bucket),\n",
    "            ],\n",
    "            \"Condition\": {\"Bool\": {\"aws:SecureTransport\": \"false\"}},\n",
    "        },\n",
    "    ],\n",
    "}\n",
    "\n",
    "print(policy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "policy_json = json.dumps(policy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"policy.json\", \"w\") as outfile:\n",
    "    json.dump(policy, outfile)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!cat policy.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws s3api put-bucket-policy --bucket $cloudtrail_bucket --policy file://policy.json"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create Cloud Trail"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cloudtrail = boto3.client(\"cloudtrail\")\n",
    "s3 = boto3.client(\"s3\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trails = cloudtrail.describe_trails()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(trails)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    t = cloudtrail.create_trail(Name=\"dsoaws\", S3BucketName=cloudtrail_bucket, IsMultiRegionTrail=True)\n",
    "    trail_name = t[\"Name\"]\n",
    "    trail_arn = t[\"TrailARN\"]\n",
    "    cloudtrail.start_logging(Name=trail_arn)\n",
    "    print(\"Cloud Trail created. Started logging.\")\n",
    "    print(\"--------------------------------------\")\n",
    "    print(\"New Trail name: {}\".format(trail_name))\n",
    "    print(\"New Trail arn: {}\".format(trail_arn))\n",
    "except ClientError as e:\n",
    "    if e.response[\"Error\"][\"Code\"] == \"TrailAlreadyExistsException\":\n",
    "        print(\"Trail already exists. This is OK.\")\n",
    "        print(\"------------------\")\n",
    "        t = cloudtrail.get_trail(Name=\"dsoaws\")\n",
    "        trail_name = t[\"Trail\"][\"Name\"]\n",
    "        trail_arn = t[\"Trail\"][\"TrailARN\"]\n",
    "        print(\"Trail name: {}\".format(trail_name))\n",
    "        print(\"Trail arn: {}\".format(trail_arn))\n",
    "    else:\n",
    "        print(\"Unexpected error: %s\" % e)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get Default EventBridge EventBus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "events = boto3.client(\"events\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = events.describe_event_bus(Name=\"default\")\n",
    "eventbus_arn = response[\"Arn\"]\n",
    "print(\"Bus {}\".format(eventbus_arn))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Data Event Logging on CloudTrail for our S3 bucket"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws cloudtrail list-trails"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws cloudtrail get-event-selectors --trail-name $trail_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "watched_bucket_arn = \"arn:aws:s3:::{}/\".format(watched_bucket)\n",
    "print(watched_bucket_arn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "event_selector = (\n",
    "    '\\'[{ \"ReadWriteType\": \"WriteOnly\", \"IncludeManagementEvents\":true, \"DataResources\": [{ \"Type\": \"AWS::S3::Object\", \"Values\": [\"'\n",
    "    + watched_bucket_arn\n",
    "    + \"\\\"] }] }]'\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(event_selector)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws cloudtrail put-event-selectors --trail-name $trail_name --event-selectors $event_selector"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Custom EventBridge Rule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pattern = {\n",
    "    \"source\": [\"aws.s3\"],\n",
    "    \"detail-type\": [\"AWS API Call via CloudTrail\"],\n",
    "    \"detail\": {\n",
    "        \"eventSource\": [\"s3.amazonaws.com\"],\n",
    "        \"eventName\": [\"PutObject\", \"CompleteMultipartUpload\", \"CopyObject\"],\n",
    "        \"requestParameters\": {\"bucketName\": [\"{}\".format(watched_bucket)]},\n",
    "    },\n",
    "}\n",
    "\n",
    "pattern_json = json.dumps(pattern)\n",
    "print(pattern_json)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = events.put_rule(\n",
    "    Name=\"S3-Trigger\",\n",
    "    EventPattern=pattern_json,\n",
    "    State=\"ENABLED\",\n",
    "    Description=\"Triggers an event on S3 PUT\",\n",
    "    EventBusName=\"default\",\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rule_arn = response[\"RuleArn\"]\n",
    "print(rule_arn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Add Target"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create IAM Role"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "iam = boto3.client(\"iam\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "iam_role_name_eventbridge = \"DSOAWS_EventBridge_Invoke_StepFunctions\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create AssumeRolePolicyDocument"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assume_role_policy_doc = {\n",
    "    \"Version\": \"2012-10-17\",\n",
    "    \"Statement\": [{\"Effect\": \"Allow\", \"Principal\": {\"Service\": \"events.amazonaws.com\"}, \"Action\": \"sts:AssumeRole\"}],\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    iam_role_eventbridge = iam.create_role(\n",
    "        RoleName=iam_role_name_eventbridge,\n",
    "        AssumeRolePolicyDocument=json.dumps(assume_role_policy_doc),\n",
    "        Description=\"DSOAWS EventBridge Role\",\n",
    "    )\n",
    "except ClientError as e:\n",
    "    if e.response[\"Error\"][\"Code\"] == \"EntityAlreadyExists\":\n",
    "        print(\"Role already exists\")\n",
    "    else:\n",
    "        print(\"Unexpected error: %s\" % e)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get the Role ARN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "role_eventbridge = iam.get_role(RoleName=iam_role_name_eventbridge)\n",
    "iam_role_eventbridge_arn = role_eventbridge[\"Role\"][\"Arn\"]\n",
    "print(iam_role_eventbridge_arn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define Eventbridge Policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eventbridge_sfn_policy = {\n",
    "    \"Version\": \"2012-10-17\",\n",
    "    \"Statement\": [{\"Sid\": \"VisualEditor0\", \"Effect\": \"Allow\", \"Action\": \"states:StartExecution\", \"Resource\": \"*\"}],\n",
    "}\n",
    "\n",
    "print(eventbridge_sfn_policy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create Policy Object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    policy_eventbridge_sfn = iam.create_policy(\n",
    "        PolicyName=\"DSOAWS_EventBridgeInvokeStepFunction\", PolicyDocument=json.dumps(eventbridge_sfn_policy)\n",
    "    )\n",
    "    print(\"Done.\")\n",
    "except ClientError as e:\n",
    "    if e.response[\"Error\"][\"Code\"] == \"EntityAlreadyExists\":\n",
    "        print(\"Policy already exists\")\n",
    "        policy_eventbridge_sfn_arn = f\"arn:aws:iam::{account_id}:policy/DSOAWS_EventBridgeInvokeStepFunction\"\n",
    "        iam.create_policy_version(\n",
    "            PolicyArn=policy_eventbridge_sfn_arn, PolicyDocument=json.dumps(eventbridge_sfn_policy), SetAsDefault=True\n",
    "        )\n",
    "        print(\"Policy updated.\")\n",
    "    else:\n",
    "        print(\"Unexpected error: %s\" % e)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Get ARN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "policy_eventbridge_sfn_arn = f\"arn:aws:iam::{account_id}:policy/DSOAWS_EventBridgeInvokeStepFunction\"\n",
    "print(policy_eventbridge_sfn_arn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Attach Policy To Role"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    response = iam.attach_role_policy(PolicyArn=policy_eventbridge_sfn_arn, RoleName=iam_role_name_eventbridge)\n",
    "    print(\"Done.\")\n",
    "except ClientError as e:\n",
    "    if e.response[\"Error\"][\"Code\"] == \"EntityAlreadyExists\":\n",
    "        print(\"Policy is already attached. This is ok.\")\n",
    "    else:\n",
    "        print(\"Unexpected error: %s\" % e)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup EventBridge Rule Target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sfn = boto3.client(\"stepfunctions\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define Model Pipeline Inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "timestamp = int(time.time())\n",
    "\n",
    "execution_name = \"run-{}\".format(timestamp)\n",
    "print(execution_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Specify the Raw Inputs S3 Location\n",
    "TODO:  Change this to the watched input location"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_input_data_s3_uri = \"s3://{}/amazon-reviews-pds/tsv/\".format(bucket)\n",
    "print(raw_input_data_s3_uri)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set the Processing Hyper-Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_seq_length = 64\n",
    "train_split_percentage = 0.90\n",
    "validation_split_percentage = 0.05\n",
    "test_split_percentage = 0.05\n",
    "balance_dataset = True\n",
    "processing_instance_count = 2\n",
    "processing_instance_type = \"ml.c5.2xlarge\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup Training Hyper-Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 1\n",
    "learning_rate = 0.00001\n",
    "epsilon = 0.00000001\n",
    "train_batch_size = 128\n",
    "validation_batch_size = 128\n",
    "test_batch_size = 128\n",
    "train_steps_per_epoch = 100\n",
    "validation_steps = 100\n",
    "test_steps = 100\n",
    "train_instance_count = 1\n",
    "train_instance_type = \"ml.c5.9xlarge\"\n",
    "train_volume_size = 1024\n",
    "use_xla = True\n",
    "use_amp = True\n",
    "freeze_bert_layer = False\n",
    "enable_sagemaker_debugger = False\n",
    "enable_checkpointing = False\n",
    "enable_tensorboard = False\n",
    "input_mode = \"File\"\n",
    "run_validation = True\n",
    "run_test = True\n",
    "run_sample_predictions = True\n",
    "deploy_instance_count = 1\n",
    "# deploy_instance_type='ml.m5.4xlarge'\n",
    "deploy_instance_type = \"ml.m5.large\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note:  Below, we are re-using the `sourcedir.tar.gz` (contains `tf_bert_reviews.py`) uploaded during a previous notebook's `sagemaker.estimator.TensorFlow.fit()` invocation.  We could manually copy the source to an S3 location and use this for the location of the `sourcedir.tar.gz`, but we choose to re-use for now."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r processing_code_s3_prefix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(processing_code_s3_prefix)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Find the AWS ECR account which hosts the scikit-learn docker image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# You find the regional AWS ECR account IDs storing the docker images here:\n",
    "# https://docs.aws.amazon.com/sagemaker/latest/dg/pre-built-docker-containers-frameworks.html\n",
    "account_id_scikit_learn_image_us_east_1 = \"683313688378\"\n",
    "account_id_scikit_learn_image_us_west_2 = \"246618743249\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "account_id_scikit_learn_image = \"\"\n",
    "if region == \"us-east-1\":\n",
    "    account_id_scikit_learn_image = account_id_scikit_learn_image_us_east_1\n",
    "elif region == \"us-west-2\":\n",
    "    account_id_scikit_learn_image = account_id_scikit_learn_image_us_west_2\n",
    "else:\n",
    "    print(\"Please look up the correct AWS ECR Account ID per Link above.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(account_id_scikit_learn_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = {\n",
    "    \"Processing Job\": {\n",
    "        \"ProcessingJobName\": \"training-pipeline-{}\".format(execution_name),\n",
    "        \"ProcessingInputs\": [\n",
    "            {\n",
    "                \"InputName\": \"raw_input\",\n",
    "                \"S3Input\": {\n",
    "                    # TODO:  Change to watched_bucket + watched_s3_prefix\n",
    "                    #           \"S3Uri\": \"s3://{}/{}/\".format(watched_bucket, watched_s3_prefix),\n",
    "                    \"S3Uri\": \"{}\".format(raw_input_data_s3_uri),\n",
    "                    \"LocalPath\": \"/opt/ml/processing/input/data/\",\n",
    "                    \"S3DataType\": \"S3Prefix\",\n",
    "                    \"S3InputMode\": \"File\",\n",
    "                    \"S3DataDistributionType\": \"ShardedByS3Key\",\n",
    "                    \"S3CompressionType\": \"None\",\n",
    "                },\n",
    "            },\n",
    "            {\n",
    "                \"InputName\": \"code\",\n",
    "                \"S3Input\": {\n",
    "                    \"S3Uri\": \"s3://{}/{}/preprocess-scikit-text-to-bert.py\".format(bucket, processing_code_s3_prefix),\n",
    "                    \"LocalPath\": \"/opt/ml/processing/input/code\",\n",
    "                    \"S3DataType\": \"S3Prefix\",\n",
    "                    \"S3InputMode\": \"File\",\n",
    "                    \"S3DataDistributionType\": \"FullyReplicated\",\n",
    "                    \"S3CompressionType\": \"None\",\n",
    "                },\n",
    "            },\n",
    "        ],\n",
    "        \"ProcessingOutputConfig\": {\n",
    "            \"Outputs\": [\n",
    "                {\n",
    "                    \"OutputName\": \"bert-train\",\n",
    "                    \"S3Output\": {\n",
    "                        \"S3Uri\": \"s3://{}/{}/processing/output/bert-train\".format(bucket, execution_name),\n",
    "                        \"LocalPath\": \"/opt/ml/processing/output/bert/train\",\n",
    "                        \"S3UploadMode\": \"EndOfJob\",\n",
    "                    },\n",
    "                },\n",
    "                {\n",
    "                    \"OutputName\": \"bert-validation\",\n",
    "                    \"S3Output\": {\n",
    "                        \"S3Uri\": \"s3://{}/{}/processing/output/bert-validation\".format(bucket, execution_name),\n",
    "                        \"LocalPath\": \"/opt/ml/processing/output/bert/validation\",\n",
    "                        \"S3UploadMode\": \"EndOfJob\",\n",
    "                    },\n",
    "                },\n",
    "                {\n",
    "                    \"OutputName\": \"bert-test\",\n",
    "                    \"S3Output\": {\n",
    "                        \"S3Uri\": \"s3://{}/{}/processing/output/bert-test\".format(bucket, execution_name),\n",
    "                        \"LocalPath\": \"/opt/ml/processing/output/bert/test\",\n",
    "                        \"S3UploadMode\": \"EndOfJob\",\n",
    "                    },\n",
    "                },\n",
    "            ]\n",
    "        },\n",
    "        \"AppSpecification\": {\n",
    "            \"ImageUri\": \"{}.dkr.ecr.{}.amazonaws.com/sagemaker-scikit-learn:0.23-1-cpu-py3\".format(\n",
    "                account_id_scikit_learn_image, region\n",
    "            ),\n",
    "            \"ContainerArguments\": [\n",
    "                \"--train-split-percentage\",\n",
    "                \"{}\".format(train_split_percentage),\n",
    "                \"--validation-split-percentage\",\n",
    "                \"{}\".format(validation_split_percentage),\n",
    "                \"--test-split-percentage\",\n",
    "                \"{}\".format(test_split_percentage),\n",
    "                \"--max-seq-length\",\n",
    "                \"{}\".format(max_seq_length),\n",
    "                \"--balance-dataset\",\n",
    "                \"{}\".format(balance_dataset),\n",
    "            ],\n",
    "            \"ContainerEntrypoint\": [\"python3\", \"/opt/ml/processing/input/code/preprocess-scikit-text-to-bert.py\"],\n",
    "        },\n",
    "        \"RoleArn\": \"{}\".format(role),\n",
    "        \"ProcessingResources\": {\n",
    "            \"ClusterConfig\": {\n",
    "                \"InstanceCount\": processing_instance_count,\n",
    "                \"InstanceType\": \"{}\".format(processing_instance_type),\n",
    "                \"VolumeSizeInGB\": 30,\n",
    "            }\n",
    "        },\n",
    "        \"StoppingCondition\": {\"MaxRuntimeInSeconds\": 7200},\n",
    "    },\n",
    "    \"Training\": {\n",
    "        \"AlgorithmSpecification\": {\n",
    "            \"TrainingImage\": \"763104351884.dkr.ecr.{}.amazonaws.com/tensorflow-training:2.1.0-cpu-py36-ubuntu18.04\".format(\n",
    "                region\n",
    "            ),\n",
    "            \"TrainingInputMode\": \"{}\".format(input_mode),\n",
    "        },\n",
    "        \"OutputDataConfig\": {\"S3OutputPath\": \"s3://{}/training-pipeline-{}/models\".format(bucket, execution_name)},\n",
    "        \"StoppingCondition\": {\"MaxRuntimeInSeconds\": 7200},\n",
    "        \"ResourceConfig\": {\n",
    "            \"InstanceCount\": train_instance_count,\n",
    "            \"InstanceType\": \"{}\".format(train_instance_type),\n",
    "            \"VolumeSizeInGB\": train_volume_size,\n",
    "        },\n",
    "        \"RoleArn\": \"{}\".format(role),\n",
    "        \"InputDataConfig\": [\n",
    "            {\n",
    "                \"DataSource\": {\n",
    "                    \"S3DataSource\": {\n",
    "                        \"S3DataType\": \"S3Prefix\",\n",
    "                        \"S3Uri\": \"s3://{}/{}/processing/output/bert-train\".format(bucket, execution_name),\n",
    "                        \"S3DataDistributionType\": \"ShardedByS3Key\",\n",
    "                    }\n",
    "                },\n",
    "                \"ChannelName\": \"train\",\n",
    "            },\n",
    "            {\n",
    "                \"DataSource\": {\n",
    "                    \"S3DataSource\": {\n",
    "                        \"S3DataType\": \"S3Prefix\",\n",
    "                        \"S3Uri\": \"s3://{}/{}/processing/output/bert-validation\".format(bucket, execution_name),\n",
    "                        \"S3DataDistributionType\": \"ShardedByS3Key\",\n",
    "                    }\n",
    "                },\n",
    "                \"ChannelName\": \"validation\",\n",
    "            },\n",
    "            {\n",
    "                \"DataSource\": {\n",
    "                    \"S3DataSource\": {\n",
    "                        \"S3DataType\": \"S3Prefix\",\n",
    "                        \"S3Uri\": \"s3://{}/{}/processing/output/bert-test\".format(bucket, execution_name),\n",
    "                        \"S3DataDistributionType\": \"ShardedByS3Key\",\n",
    "                    }\n",
    "                },\n",
    "                \"ChannelName\": \"test\",\n",
    "            },\n",
    "        ],\n",
    "        \"HyperParameters\": {\n",
    "            \"epochs\": \"{}\".format(epochs),\n",
    "            \"learning_rate\": \"{}\".format(learning_rate),\n",
    "            \"epsilon\": \"{}\".format(epsilon),\n",
    "            \"train_batch_size\": \"{}\".format(train_batch_size),\n",
    "            \"validation_batch_size\": \"{}\".format(validation_batch_size),\n",
    "            \"test_batch_size\": \"{}\".format(test_batch_size),\n",
    "            \"train_steps_per_epoch\": \"{}\".format(train_steps_per_epoch),\n",
    "            \"validation_steps\": \"{}\".format(validation_steps),\n",
    "            \"test_steps\": \"{}\".format(test_steps),\n",
    "            \"use_xla\": \"{}\".format(str(use_xla).lower()),\n",
    "            \"use_amp\": \"{}\".format(str(use_amp).lower()),\n",
    "            \"max_seq_length\": \"{}\".format(max_seq_length),\n",
    "            \"freeze_bert_layer\": \"{}\".format(str(freeze_bert_layer).lower()),\n",
    "            \"enable_sagemaker_debugger\": \"{}\".format(str(enable_sagemaker_debugger).lower()),\n",
    "            \"enable_checkpointing\": \"{}\".format(str(enable_checkpointing).lower()),\n",
    "            \"enable_tensorboard\": \"{}\".format(str(enable_tensorboard).lower()),\n",
    "            \"run_validation\": \"{}\".format(str(run_validation).lower()),\n",
    "            \"run_test\": \"{}\".format(str(run_test).lower()),\n",
    "            \"run_sample_predictions\": \"{}\".format(str(run_sample_predictions).lower()),\n",
    "            \"sagemaker_submit_directory\": '\"s3://{}/{}/estimator-source/source/sourcedir.tar.gz\"'.format(\n",
    "                bucket, stepfunction_name\n",
    "            ),\n",
    "            \"sagemaker_program\": '\"tf_bert_reviews.py\"',\n",
    "            \"sagemaker_enable_cloudwatch_metrics\": \"false\",\n",
    "            \"sagemaker_container_log_level\": \"20\",\n",
    "            \"sagemaker_job_name\": '\"training-pipeline-{}/estimator-source\"'.format(execution_name),\n",
    "            \"sagemaker_region\": '\"{}\"'.format(region),\n",
    "            \"model_dir\": '\"s3://{}/training-pipeline-{}/estimator-source/model\"'.format(bucket, execution_name),\n",
    "        },\n",
    "        \"TrainingJobName\": \"estimator-training-pipeline-{}\".format(execution_name),\n",
    "        \"DebugHookConfig\": {\"S3OutputPath\": \"s3://{}/\".format(bucket)},\n",
    "    },\n",
    "    \"Create Model\": {\n",
    "        \"ModelName\": \"training-pipeline-{}\".format(execution_name),\n",
    "        \"PrimaryContainer\": {\n",
    "            \"Image\": \"763104351884.dkr.ecr.{}.amazonaws.com/tensorflow-inference:2.1.0-cpu-py36-ubuntu18.04\".format(\n",
    "                region\n",
    "            ),\n",
    "            \"Environment\": {\n",
    "                \"SAGEMAKER_PROGRAM\": \"null\",\n",
    "                \"SAGEMAKER_SUBMIT_DIRECTORY\": \"null\",\n",
    "                \"SAGEMAKER_ENABLE_CLOUDWATCH_METRICS\": \"false\",\n",
    "                \"SAGEMAKER_CONTAINER_LOG_LEVEL\": \"20\",\n",
    "                \"SAGEMAKER_REGION\": \"{}\".format(region),\n",
    "            },\n",
    "            \"ModelDataUrl\": \"s3://{}/training-pipeline-{}/models/estimator-training-pipeline-{}/output/model.tar.gz\".format(\n",
    "                bucket, execution_name, execution_name\n",
    "            ),\n",
    "        },\n",
    "        \"ExecutionRoleArn\": \"{}\".format(role),\n",
    "    },\n",
    "    \"Configure Endpoint\": {\n",
    "        \"EndpointConfigName\": \"training-pipeline-{}\".format(execution_name),\n",
    "        \"ProductionVariants\": [\n",
    "            {\n",
    "                \"InitialInstanceCount\": deploy_instance_count,\n",
    "                \"InstanceType\": \"{}\".format(deploy_instance_type),\n",
    "                \"ModelName\": \"training-pipeline-{}\".format(execution_name),\n",
    "                \"VariantName\": \"AllTraffic\",\n",
    "            }\n",
    "        ],\n",
    "    },\n",
    "    \"Deploy\": {\n",
    "        \"EndpointConfigName\": \"training-pipeline-{}\".format(execution_name),\n",
    "        \"EndpointName\": \"training-pipeline-{}\".format(execution_name),\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs_json = json.dumps(inputs)\n",
    "\n",
    "print(inputs_json)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create EventBridge Rule Target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check for existing targets\n",
    "targets = events.list_targets_by_rule(Rule=\"S3-Trigger\", EventBusName=\"default\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "number_targets = len(targets[\"Targets\"])\n",
    "\n",
    "if number_targets > 0:\n",
    "    for target in targets[\"Targets\"]:\n",
    "        events.remove_targets(Rule=\"S3-Trigger\", EventBusName=\"default\", Ids=[target[\"Id\"]], Force=True)\n",
    "        # Report each removal inside the loop so every target is listed, not only the last one\n",
    "        print(\"Target: \" + target[\"Id\"] + \" removed.\")\n",
    "else:\n",
    "    print(\"No targets defined yet.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import uuid\n",
    "\n",
    "target_id = str(uuid.uuid4())\n",
    "\n",
    "response = events.put_targets(\n",
    "    Rule=\"S3-Trigger\",\n",
    "    EventBusName=\"default\",\n",
    "    Targets=[{\"Id\": target_id, \"Arn\": stepfunction_arn, \"RoleArn\": iam_role_eventbridge_arn, \"Input\": inputs_json}],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Check Number of StepFunction Invocations **Before** the S3 Trigger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "execution_list_before_uploading = sfn.list_executions(stateMachineArn=stepfunction_arn)\n",
    "\n",
    "number_of_executions_before_uploading = len(execution_list_before_uploading[\"executions\"])\n",
    "\n",
    "print(number_of_executions_before_uploading)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Upload to S3 and Trigger a StepFunction Invocation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "time.sleep(15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "watched_s3_uri = \"s3://{}/watched_input/\".format(watched_bucket)\n",
    "\n",
    "print('Uploading training data to \"{}\" to trigger a new training pipeline.'.format(watched_s3_uri))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws s3 cp ./data-tfrecord/bert-train/part-algo-1-amazon_reviews_us_Digital_Software_v1_00.tfrecord $watched_s3_uri"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "time.sleep(30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws s3 cp ./data-tfrecord/bert-train/part-algo-1-amazon_reviews_us_Digital_Software_v1_00.tfrecord $watched_s3_uri"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Check Number of StepFunction Invocations **After** the S3 Trigger (Wait for 60 seconds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "time.sleep(60)"
   ]
  },
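  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A fixed wait can be too short if the event is delivered late. As an optional alternative, the cell below is a minimal polling sketch. The helper name `wait_for_new_execution` and its timeout values are hypothetical, and it assumes a single unpaginated `list_executions` call is enough to see the new execution, which may not hold once the state machine has more executions than one page returns."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "\n",
    "def wait_for_new_execution(sfn_client, state_machine_arn, previous_count, timeout_seconds=300, poll_seconds=15):\n",
    "    \"\"\"Poll list_executions until more executions exist than previous_count, or the timeout passes.\"\"\"\n",
    "    deadline = time.time() + timeout_seconds\n",
    "    while time.time() < deadline:\n",
    "        executions = sfn_client.list_executions(stateMachineArn=state_machine_arn)[\"executions\"]\n",
    "        if len(executions) > previous_count:\n",
    "            return executions\n",
    "        time.sleep(poll_seconds)\n",
    "    return None  # Timed out without seeing a new execution\n",
    "\n",
    "\n",
    "new_execution_list = wait_for_new_execution(sfn, stepfunction_arn, number_of_executions_before_uploading)\n",
    "\n",
    "print(\"New execution detected.\" if new_execution_list else \"Timed out waiting for a new execution.\")"
   ]
  },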
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "execution_list_after_uploading = sfn.list_executions(stateMachineArn=stepfunction_arn)\n",
    "\n",
    "print(execution_list_after_uploading)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "number_of_executions_after_uploading = len(execution_list_after_uploading[\"executions\"])\n",
    "\n",
    "print(number_of_executions_after_uploading)"
   ]
  },
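  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The counts taken before and after the upload can be compared to confirm that the S3 trigger started at least one new execution. The cell below is a minimal sketch of that check. The helper name `count_new_executions` is hypothetical, and the comparison assumes both counts come from a single unpaginated `list_executions` call, so it may undercount on a state machine with a long execution history."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_new_executions(before, after):\n",
    "    \"\"\"Return how many executions were started between the two counts (never negative).\"\"\"\n",
    "    return max(after - before, 0)\n",
    "\n",
    "\n",
    "new_executions = count_new_executions(\n",
    "    number_of_executions_before_uploading, number_of_executions_after_uploading\n",
    ")\n",
    "\n",
    "if new_executions > 0:\n",
    "    print(\"{} new execution(s) started after the S3 upload.\".format(new_executions))\n",
    "else:\n",
    "    print(\"No new executions yet. The event may not have been delivered; wait and list the executions again.\")"
   ]
  },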
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# list_executions returns the most recent execution first\n",
    "current_execution = execution_list_after_uploading[\"executions\"][0]\n",
    "\n",
    "current_execution_arn = current_execution[\"executionArn\"]\n",
    "\n",
    "print(current_execution_arn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, HTML\n",
    "\n",
    "display(\n",
    "    HTML(\n",
    "        '<b>Review <a target=\"_blank\" href=\"https://console.aws.amazon.com/states/home?region={}#/executions/details/{}\">Step Functions Pipeline</a></b>'.format(\n",
    "            region, current_execution_arn\n",
    "        )\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%javascript\n",
    "Jupyter.notebook.save_checkpoint()\n",
    "Jupyter.notebook.session.delete();"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
